#!/usr/bin/env python3
# E15_v10 — Deflection (finite deflect-zone + hazard/chord on east-steps; control unchanged)
import argparse
import csv
import hashlib
import json
import math
import os
import sys
from datetime import datetime, timezone
from typing import Dict, List, Tuple

def utc_timestamp() -> str:
    """Return the current UTC time as a filename-safe stamp, e.g. ``2024-01-31T12-05-09Z``."""
    now_utc = datetime.now(tz=timezone.utc)
    # Hyphens instead of colons in the time part keep the stamp safe for file names.
    return now_utc.strftime("%Y-%m-%dT%H-%M-%SZ")

def ensure_dirs(root, subs):
    """Create every relative path in *subs* beneath *root*.

    Intermediate directories are created as needed, and directories that
    already exist are left untouched.
    """
    for sub in subs:
        target = os.path.join(root, sub)
        os.makedirs(target, exist_ok=True)

def write_text(p, t):
    """Overwrite the file at *p* with the string *t*, encoded as UTF-8."""
    with open(p, mode="w", encoding="utf-8") as handle:
        handle.write(t)

def json_dump(p, o):
    """Serialize *o* as JSON into the file *p* (UTF-8, 2-space indent, keys sorted)."""
    with open(p, mode="w", encoding="utf-8") as handle:
        json.dump(o, handle, sort_keys=True, indent=2)

def sha256_file(p):
    """Return the lowercase hex SHA-256 digest of the file at *p*.

    The file is read in 1 MiB binary chunks, so large files are hashed
    without being loaded into memory at once.
    """
    # Uses the module-level ``hashlib`` import; no need to re-import locally.
    h = hashlib.sha256()
    with open(p, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

def isqrt(n: int) -> int:
    """Return ``floor(sqrt(n))`` for a non-negative integer *n* using exact integer math."""
    floor_root = math.isqrt(n)  # raises ValueError for negative n
    return int(floor_root)

def step_toward_center(y: int, cy: int) -> int:
    """Move *y* one unit toward *cy*; return *y* unchanged when it already equals *cy*."""
    if y > cy:
        return y - 1
    if y < cy:
        return y + 1
    return y

def trace_path(N:int, cx:int, cy:int, x0:int, x1:int, y0:int, kappa:int, R_defl:int) -> Dict[str, object]:
    """Trace one lattice ray eastward through a finite circular deflection zone.

    The ray starts at ``(x0, y0)``. Each loop iteration ("tick") it takes exactly
    one lattice step. It stops once ``x`` reaches ``x1`` or the tick budget
    ``guard = (x1 - x0) * (kappa + 512)`` is used up.

    Zone membership uses the integer radius ``isqrt(dx*dx + dy*dy)`` measured from
    ``(cx, cy)``, with a squared radius of 0 bumped to 1. A point is inside when
    that radius is ``<= R_defl``.

    * Inside the zone, ``kappa`` is added to the accumulator ``A`` every tick.
      - If ``A >= r2`` and the ray is not already on row ``cy``, it steps one row
        toward ``cy`` and ``r2`` is subtracted from ``A``.
      - Otherwise it steps east (``x += 1``). ``A`` is not reset in that case.
    * Outside the zone it always steps east.

    Args:
        N: Grid size. NOTE(review): not referenced anywhere in this body.
        cx, cy: Center of the deflection zone.
        x0: Starting column.
        x1: Stop column; the loop ends once ``x`` reaches it.
        y0: Starting row.
        kappa: Amount added to ``A`` on every in-zone tick.
        R_defl: Zone radius, compared against the integer radius.

    Returns:
        A dict with these keys:

        * ``entered_zone``: the ray was inside on at least one tick.
        * ``exited_zone``: an outside tick occurred after entering.
        * ``east_in`` / ``vert_in``: counts of east / vertical steps taken from
          inside the zone.
        * ``east_x_inside``: the new ``x`` after each in-zone east step whose
          landing point is still inside.
        * ``ticks``: number of loop iterations executed.
    """
    x, y = x0, y0
    A = 0  # accumulator: grows by kappa per in-zone tick, spent in units of r2 on vertical steps
    east_in = 0
    vert_in = 0
    entered = False
    exited  = False
    east_x_inside = []  # x positions AFTER an in-zone east step whose landing point is still inside the zone

    # Hard cap on iterations so the loop always terminates.
    guard = (x1 - x0) * (kappa + 512)
    t = 0
    while x < x1 and t < guard:
        dx = x - cx; dy = y - cy
        r2 = dx*dx + dy*dy
        # Exact center: use r2 = 1 (presumably so the A >= r2 threshold is never 0).
        if r2 == 0: r2 = 1
        r = isqrt(r2)
        inside = (r <= R_defl)
        if inside: entered = True

        if inside:
            A += kappa
            # Vertical step only when enough has accumulated and we are off the center row.
            if A >= r2 and y != cy:
                y = step_toward_center(y, cy)
                A -= r2
                vert_in += 1
            else:
                x += 1
                east_in += 1
                # record x after east step if still inside
                dx2 = x - cx; dy2 = y - cy
                r2b = dx2*dx2 + dy2*dy2
                if r2b == 0: r2b = 1
                if isqrt(r2b) <= R_defl:
                    east_x_inside.append(x)
        else:
            x += 1
            # The first outside tick after having been inside marks the exit.
            if entered and not exited:
                exited = True
        t += 1

    return {
        "entered_zone": entered,
        "exited_zone": exited,
        "east_in": east_in,
        "vert_in": vert_in,
        "east_x_inside": east_x_inside,
        "ticks": t
    }

def r2_y_on_x(xs: List[float], ys: List[float]) -> float:
    """Return the R^2 of an ordinary least-squares regression of ``ys`` on ``xs``.

    NaN is returned in two cases:

    * fewer than two x values are given;
    * every x is identical, so the slope is undefined.

    When ``ys`` has zero spread, the fit is treated as perfect and 1.0 is returned.
    """
    count = len(xs)
    if count < 2:
        return float("nan")
    x_mean = sum(xs) / count
    y_mean = sum(ys) / count

    idx = range(count)
    cross = sum((xs[k] - x_mean) * (ys[k] - y_mean) for k in idx)
    spread_x = sum((xs[k] - x_mean) * (xs[k] - x_mean) for k in idx)
    if spread_x == 0:
        return float("nan")
    slope = cross / spread_x

    total_ss = sum((yv - y_mean) * (yv - y_mean) for yv in ys)
    resid_ss = sum((ys[k] - (y_mean + slope * (xs[k] - x_mean))) ** 2 for k in idx)
    if total_ss > 0:
        return 1.0 - resid_ss / total_ss
    return 1.0

# Column order of the per-ray CSV written by ``run_panel``.
_PER_RAY_COLUMNS = ["b_shells", "theta_rad", "theta_deg", "phi_theta_times_b", "east_in", "vert_in",
                    "entered_zone", "exited_zone", "hazard_I", "L_chord", "ratio_I_over_L", "ticks"]


def _fmt_float(value: float, digits: int) -> str:
    """Fixed-point format *value*; Python renders NaN as 'nan' and infinities as 'inf'/'-inf'."""
    return f"{value:.{digits}f}"


def _ray_row(b: int, trace: Dict[str, object], cx: int, kappa: int, R_defl: int) -> dict:
    """Build the per-ray metrics row (chord, hazard integral, deflection angle) for one trace."""
    # Chord of the zone circle at distance |b| from the center; NaN when the ray misses it.
    if abs(b) >= R_defl:
        L_chord = float("nan")
    else:
        L_chord = 2.0 * math.sqrt(float(R_defl * R_defl - b * b))

    # Hazard integral: fixed-b kernel kappa / (dx^2 + b^2), summed over in-zone east landings only.
    hazard = 0.0
    for xe in trace["east_x_inside"]:
        dx = float(xe - cx)
        r2_fixed = dx * dx + float(b * b)
        if r2_fixed <= 0:
            r2_fixed = 1.0
        hazard += kappa / r2_fixed

    # A NaN chord (the comparison is False) propagates NaN into ratio, theta, theta_deg and phi.
    ratio = hazard / L_chord if L_chord > 0 else float("nan")
    theta = math.atan(ratio)

    return {
        "b_shells": b,
        "theta_rad": theta,
        "theta_deg": theta * 180.0 / math.pi,
        "phi_theta_times_b": theta * b,
        "east_in": trace["east_in"], "vert_in": trace["vert_in"],
        "entered_zone": trace["entered_zone"], "exited_zone": trace["exited_zone"],
        "hazard_I": hazard, "L_chord": L_chord, "ratio_I_over_L": ratio,
        "ticks": trace["ticks"],
    }


def _write_per_ray_csv(path: str, rows: List[dict]) -> None:
    """Write one CSV line per traced ray in ``_PER_RAY_COLUMNS`` order."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(_PER_RAY_COLUMNS)
        for row in rows:
            w.writerow([row["b_shells"],
                        _fmt_float(row["theta_rad"], 12),
                        _fmt_float(row["theta_deg"], 9),
                        _fmt_float(row["phi_theta_times_b"], 12),
                        row["east_in"], row["vert_in"],
                        int(row["entered_zone"]), int(row["exited_zone"]),
                        _fmt_float(row["hazard_I"], 9),
                        _fmt_float(row["L_chord"], 6),
                        _fmt_float(row["ratio_I_over_L"], 9),
                        row["ticks"]])


def _window_metrics(rows: List[dict], cert_min: int, cert_max: int,
                    mono_tol: float, ratio_max_allowed: float) -> dict:
    """Score the certification window of a panel.

    The window keeps rows that satisfy all of:

    * ``cert_min <= b <= cert_max``;
    * ``b != 0``, because the 1/b regressor is undefined there;
    * the ray entered the zone;
    * ``theta`` is finite.

    With at least two window rows:

    * ``A_hat``: the mean of ``theta * b``.
    * ``phi_rel_rmse``: the RMS deviation of ``theta * b`` about that mean,
      divided by ``|mean|`` (inf if the mean is 0).
    * ``r2_1overb``: the R^2 of ``theta`` against ``1/b``.
    * ``mono_ok``: ``theta`` does not decrease as ``b`` decreases, within ``mono_tol``.
    * ``coverage_ok``: every window ray exited the zone and took at least one
      in-zone east step.
    * ``ratio_max``: the largest hazard/chord ratio.
    * ``small_angle_ok``: ``ratio_max <= ratio_max_allowed``.

    Otherwise the numeric metrics are NaN, all flags are False, and
    ``ratio_max`` is inf.
    """
    window = [row for row in rows
              if cert_min <= row["b_shells"] <= cert_max
              and row["b_shells"] != 0
              and row["entered_zone"]
              and not math.isnan(row["theta_rad"])]

    if len(window) < 2:
        nan = float("nan")
        return {"window_rows": window, "A_hat": nan, "phi_rel_rmse": nan, "r2_1overb": nan,
                "mono_ok": False, "coverage_ok": False, "small_angle_ok": False,
                "ratio_max": float("inf")}

    mu = sum(row["phi_theta_times_b"] for row in window) / len(window)
    rmse = math.sqrt(sum((row["phi_theta_times_b"] - mu) ** 2 for row in window) / len(window))
    rel_rmse = (rmse / abs(mu)) if abs(mu) > 0 else float("inf")

    r2 = r2_y_on_x([1.0 / row["b_shells"] for row in window],
                   [row["theta_rad"] for row in window])

    by_b_desc = sorted(window, key=lambda row: row["b_shells"], reverse=True)
    mono_ok = all(cur["theta_rad"] + mono_tol >= prev["theta_rad"]
                  for prev, cur in zip(by_b_desc, by_b_desc[1:]))
    coverage_ok = all(row["exited_zone"] and row["east_in"] > 0 for row in by_b_desc)
    ratio_max = max(row["ratio_I_over_L"] for row in by_b_desc)

    return {"window_rows": window, "A_hat": mu, "phi_rel_rmse": rel_rmse, "r2_1overb": r2,
            "mono_ok": mono_ok, "coverage_ok": coverage_ok,
            "small_angle_ok": ratio_max <= ratio_max_allowed, "ratio_max": ratio_max}


def run_panel(M: dict, N: int, outdir: str, tag: str) -> dict:
    """Trace every configured impact parameter on an ``N``-wide grid and score the panel.

    Manifest keys read (defaults in parentheses):

    * ``grid.cx`` / ``grid.cy`` (``N // 2``)
    * ``source.x_margin`` (8)
    * ``deflect.kappa``
    * ``deflect.r_deflect_shells``
    * ``source.impact_params_shells``
    * ``cert_window.b_min_shells`` / ``cert_window.b_max_shells``
    * ``acceptance.mono_tol_rad`` (0.001)
    * ``acceptance.small_angle_ratio_max`` (0.25)

    Each ray starts at column ``x_margin`` on row ``cy + b`` and runs to
    column ``N - x_margin``. If row ``cy + b`` is off-grid, row ``cy - b`` is
    used; if both are off-grid the ray is skipped. Per-ray rows are written
    to ``<outdir>/e15_<tag>_per_ray.csv``.

    Returns:
        A dict with these keys:

        * ``csv``: the path written.
        * ``rows``: all per-ray rows.
        * ``window_rows``: the rows in the certification window.
        * The window metrics ``A_hat``, ``phi_rel_rmse``, ``r2_1overb``,
          ``mono_ok``, ``coverage_ok``, ``small_angle_ok`` and ``ratio_max``
          (see ``_window_metrics``).
    """
    cx = int(M["grid"].get("cx", N // 2))
    cy = int(M["grid"].get("cy", N // 2))
    x_margin = int(M["source"].get("x_margin", 8))
    x0, x1 = x_margin, N - x_margin

    kappa = int(M["deflect"]["kappa"])
    R_defl = int(M["deflect"]["r_deflect_shells"])
    b_list = sorted((int(b) for b in M["source"]["impact_params_shells"]), reverse=True)

    cert_min = int(M["cert_window"]["b_min_shells"])
    cert_max = int(M["cert_window"]["b_max_shells"])
    mono_tol = float(M["acceptance"].get("mono_tol_rad", 0.001))
    ratio_max_allowed = float(M["acceptance"].get("small_angle_ratio_max", 0.25))

    rows = []
    for b in b_list:
        # Launch above the center if that row fits on the grid, else mirror below it.
        y0 = cy + b
        if not 0 <= y0 < N:
            y0 = cy - b
            if not 0 <= y0 < N:
                continue
        trace = trace_path(N, cx, cy, x0, x1, y0, kappa, R_defl)
        rows.append(_ray_row(b, trace, cx, kappa, R_defl))

    mpath = os.path.join(outdir, f"e15_{tag}_per_ray.csv")
    _write_per_ray_csv(mpath, rows)

    metrics = _window_metrics(rows, cert_min, cert_max, mono_tol, ratio_max_allowed)
    return {"csv": mpath, "rows": rows, **metrics}

def mesh_delta(a: float, b: float) -> Tuple[float, float]:
    """Return ``(|a - b|, |a - b| / mean(|a|, |b|))``.

    The denominator is floored at 1e-12, so two zeros give ``(0.0, 0.0)``
    instead of raising ZeroDivisionError. A NaN input yields NaN in both
    outputs.
    """
    absd = abs(a - b)
    denom = max(1e-12, 0.5 * (abs(a) + abs(b)))
    return absd, absd / denom

def main():
    """Command-line entry point for the E15_v10 deflection run.

    Steps:

    1. Parse ``--manifest`` and ``--outdir``.
    2. Create the output directory tree under ``--outdir``.
    3. Copy the manifest to ``config/manifest_e15_v10.json`` (re-serialized,
       keys sorted).
    4. Write ``logs/env.txt`` with one ``key=value`` pair per line.
    5. Run the coarse panel (``grid.N``) and the fine panel (``mesh.N_fine``,
       default 768).
    6. Evaluate the acceptance criteria.
    7. Write ``outputs/audits/e15_audit.json`` and a one-line summary to
       ``outputs/run_info/result_line.txt``; the summary is also printed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--outdir", required=True)
    args = ap.parse_args()

    root = os.path.abspath(args.outdir)
    ensure_dirs(root, ["config", "outputs/metrics", "outputs/audits", "outputs/run_info", "outputs/mesh", "logs"])

    with open(args.manifest, "r", encoding="utf-8") as f:
        M = json.load(f)
    man_out = os.path.join(root, "config", "manifest_e15_v10.json")
    json_dump(man_out, M)

    # Join with a real newline so each key=value pair lands on its own line.
    env_lines = [f"utc={utc_timestamp()}",
                 f"os={os.name}", f"cwd={os.getcwd()}",
                 f"python={sys.version.split()[0]}"]
    write_text(os.path.join(root, "logs", "env.txt"), "\n".join(env_lines))

    N_coarse = int(M["grid"]["N"])
    N_fine = int(M["mesh"].get("N_fine", 768))

    metrics_dir = os.path.join(root, "outputs/metrics")
    coarse = run_panel(M, N_coarse, metrics_dir, "coarse")
    fine = run_panel(M, N_fine, metrics_dir, "fine")

    # Acceptance thresholds (manifest overrides, with defaults).
    rel_rmse_max = float(M["acceptance"].get("flat_rel_rmse_max", 0.08))
    r2_min = float(M["acceptance"].get("r2_min", 0.96))
    dabs_max = float(M["mesh"].get("delta_amp_abs_max", 0.12))
    drel_max = float(M["mesh"].get("delta_amp_rel_max", 0.20))
    require_cov = bool(M["acceptance"].get("require_coverage_ok", True))
    require_small = bool(M["acceptance"].get("require_small_angle_ok", True))

    # NaN metrics compare False, so a panel without a usable window fails these checks.
    flat_ok_c = (coarse["phi_rel_rmse"] <= rel_rmse_max)
    flat_ok_f = (fine["phi_rel_rmse"] <= rel_rmse_max)
    r2_ok_c = (coarse["r2_1overb"] >= r2_min)
    r2_ok_f = (fine["r2_1overb"] >= r2_min)
    mono_ok = bool(coarse["mono_ok"] and fine["mono_ok"])
    cov_ok = (not require_cov) or (coarse["coverage_ok"] and fine["coverage_ok"])
    small_ok = (not require_small) or (coarse["small_angle_ok"] and fine["small_angle_ok"])

    dabs, drel = mesh_delta(coarse["A_hat"], fine["A_hat"])
    mesh_ok = (dabs <= dabs_max) and (drel <= drel_max)

    passed = bool(flat_ok_c and flat_ok_f and r2_ok_c and r2_ok_f and mono_ok and mesh_ok and cov_ok and small_ok)

    panel_keys = ["csv", "A_hat", "phi_rel_rmse", "r2_1overb", "mono_ok", "coverage_ok", "small_angle_ok", "ratio_max"]
    audit = {
        "sim": "E15_deflection_v10",
        "coarse": {k: coarse[k] for k in panel_keys},
        "fine": {k: fine[k] for k in panel_keys},
        "mesh": {"delta_amp_abs": dabs, "delta_amp_rel": drel, "ok": mesh_ok},
        "accept": {
            "flat_rel_rmse_max": rel_rmse_max, "r2_min": r2_min,
            "delta_amp_abs_max": dabs_max, "delta_amp_rel_max": drel_max,
            "require_coverage_ok": require_cov, "require_small_angle_ok": require_small
        },
        "pass": passed
    }
    json_dump(os.path.join(root, "outputs/audits", "e15_audit.json"), audit)

    result_line = ("E15_v10 PASS={p} A_hat_c={ac:.6f} r2_c={r2c:.4f} "
                   "A_hat_f={af:.6f} r2_f={r2f:.4f} ΔA={da:.6f} relΔ={dr:.3f} "
                   "flatRMSE_c={rc:.4f} flatRMSE_f={rf:.4f} small_angle_ok(c/f)={sc}/{sf} mesh_ok={mk}"
                   .format(p=passed,
                           ac=coarse["A_hat"], r2c=coarse["r2_1overb"],
                           af=fine["A_hat"], r2f=fine["r2_1overb"],
                           da=dabs, dr=drel,
                           rc=coarse["phi_rel_rmse"], rf=fine["phi_rel_rmse"],
                           sc=coarse["small_angle_ok"], sf=fine["small_angle_ok"],
                           mk=mesh_ok))
    write_text(os.path.join(root, "outputs/run_info", "result_line.txt"), result_line)
    print(result_line)

if __name__ == "__main__":
    # Run the pipeline only when executed as a script, never on import.
    main()
